
Add Mixtral #2196


Open: wants to merge 18 commits into master

Conversation

@kanpuriyanawab (Collaborator) commented Apr 2, 2025:

This PR adds Mixtral to Keras Hub.

Reference

Mixtral output matching:

[Screenshots: Mixtral output matching, 2025-04-20]

@kanpuriyanawab marked this pull request as ready for review April 10, 2025 08:40
@kanpuriyanawab (Author) commented:

Output matching:

[Screenshot: output matching]

@divyashreepathihalli (Collaborator) left a comment:

Left a few comments! Please provide a demo Colab.

)
self._query_dense.build(inputs_shape)

self._key_dense = keras.layers.EinsumDense(
Reviewer comment:

Update the layer names to be compatible with `enable_lora`.
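
For illustration, a minimal sketch of the requested naming, assuming `Backbone.enable_lora` locates its target projections by layer name (names like `query`/`value`, as in other KerasHub attention layers); the dimensions and einsum equation here are illustrative, not Mixtral's actual config:

```python
import keras

# Illustrative dimensions only.
hidden_dim, num_query_heads, head_dim = 512, 8, 64

# enable_lora() matches projection layers by name, so the query/value
# EinsumDense layers must carry the expected names.
query_dense = keras.layers.EinsumDense(
    equation="bqm,muh->bquh",
    output_shape=(None, num_query_heads, head_dim),
    name="query",
)
value_dense = keras.layers.EinsumDense(
    equation="bqm,muh->bquh",
    output_shape=(None, num_query_heads, head_dim),
    name="value",
)
```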

@keras_hub_export("keras_hub.models.MixtralBackbone")
class MixtralBackbone(Backbone):
"""
The Mixtral Transformer core architecture with hyperparameters.
Reviewer comment:

The docstring's first line should directly follow the opening `"""`.

Reviewer comment:

This still needs to be changed to: `"""The Mixtral Transformer core architecture with hyperparameters.`
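
For reference, the requested form puts the summary line directly after the opening quotes:

```python
@keras_hub_export("keras_hub.models.MixtralBackbone")
class MixtralBackbone(Backbone):
    """The Mixtral Transformer core architecture with hyperparameters.

    (rest of the docstring unchanged)
    """
```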

preprocessor("League of legends")

# Tokenize a batch of sentences.
sentences = tf.constant(["Taco tuesday", "Fish taco please!"])
Reviewer comment:

why tf?
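
One backend-agnostic way to write that docstring example is to pass plain Python lists (or numpy arrays) instead of `tf.constant`, reusing the `preprocessor` from the snippet above:

```python
preprocessor("League of legends")

# Tokenize a batch of sentences; a plain list works on every backend.
sentences = ["Taco tuesday", "Fish taco please!"]
preprocessor(sentences)
```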

target_ids = keras.ops.roll(generation_ids, shift=-1, axis=1)

embeddings = None
with tf.GradientTape(watch_accessed_variables=True) as tape:
Reviewer comment:

why tf?

@kanpuriyanawab (Author) replied:

Borrowed docstring.

[Screenshot, 2025-04-16]

Reviewer comment:

We don't recommend using backend-specific examples. For generic usage, use `keras.ops` or numpy.

Reviewer comment:

There are some conflicts in the api directory due to the recent changes; please resolve.

@kanpuriyanawab (Author) replied:

Conflicts resolved.

@kanpuriyanawab (Author) replied:

> We don't recommend using backend-specific examples. For generic usage, use `keras.ops` or numpy.

@sachinprasadhs, as I mentioned above, there are already `tf.GradientTape` examples in existing model docstrings; those should be cleaned up in a separate PR.

Reviewer comment:

Let's not pile on the mess in new PRs. Let's keep it clean.

@kanpuriyanawab (Author) commented:

Mixtral output matching:

[Screenshots: Mixtral output matching, 2025-04-20]

@sachinprasadhs (Collaborator) left a comment:

Added a few more comments.

from keras import ops


# TODO: Deprecate this in favor of
Reviewer comment:

We don't support Keras 2 anymore in Keras Hub, so I guess you can get rid of this.

@kanpuriyanawab (Author) replied:

I forgot to remove this comment. But no, Keras's LayerNormalization doesn't produce the same results as this custom layernorm.
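
For context, Mixtral's reference norm is RMS-style (no mean-centering, with the statistic computed in float32), which is why `keras.layers.LayerNormalization` does not match it exactly. A minimal sketch of such a layer, with the class name and epsilon chosen for illustration rather than taken from the PR:

```python
import keras
from keras import ops


class RMSLayerNorm(keras.layers.Layer):
    """RMS norm: scale by 1/sqrt(mean(x^2)), without mean-centering."""

    def __init__(self, epsilon=1e-6, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        self.scale = self.add_weight(
            name="scale", shape=(input_shape[-1],), initializer="ones"
        )

    def call(self, x):
        # Compute the statistic in float32 for numerical stability.
        x = ops.cast(x, "float32")
        var = ops.mean(ops.square(x), axis=-1, keepdims=True)
        x = x * ops.rsqrt(var + self.epsilon)
        return ops.cast(x, self.compute_dtype) * self.scale
```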

# Below is a workaround for `ops.triu` for Keras 2.
# TODO(tirthasheshpatel): Use `ops.triu` once Keras 2 support is
# removed.
# causal_mask = ops.triu(causal_mask, k=-self.sliding_window)
Reviewer comment:

Keras 2 support is removed now, so you can enable this.

@kanpuriyanawab (Author) replied:

`ops.triu`/`ops.tril` have issues with dynamic shapes on the TensorFlow backend (see `_mask_sliding_window` in `keras_hub/src/models/gemma/gemma_attention.py`), hence I chose to keep this as it is.
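
For readers, a sketch of the `ops.arange`-based approach referenced above: it builds the band mask by comparing query and key positions, avoiding `ops.triu`/`ops.tril` entirely (illustrative, not the PR's code):

```python
from keras import ops


def sliding_window_causal_mask(seq_len, sliding_window):
    # Query positions as a column, key positions as a row.
    i = ops.expand_dims(ops.arange(seq_len), axis=1)
    j = ops.expand_dims(ops.arange(seq_len), axis=0)
    # Causal: a query may only attend to keys at or before it.
    causal = ops.less_equal(j, i)
    # Window: keys must lie within `sliding_window` positions of the query.
    in_window = ops.greater(j, i - sliding_window)
    return ops.logical_and(causal, in_window)
```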

@kanpuriyanawab (Author) replied:

Updated the comment, though!

Reviewer comment:

Okay, can you remove the line "# Below is a workaround for `ops.triu` for Keras 2."?

@sachinprasadhs (Collaborator) left a comment:

Thanks! Left a few comments with small changes.

`tf.RaggedTensor` where the last dimension of the output is ragged.

If input is a scalar string (rank == 0), the layer will output a dense
`tf.Tensor` with static shape `[None]`.
Reviewer comment:

This needs to be corrected, since this is not specific to the TF backend.


init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output_shape=(2, 5, 16),
run_quantization_check=False,
Reviewer comment:

Can you enable this test?
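
For clarity, the requested change, assuming this snippet comes from the standard `run_backbone_test` call in the KerasHub test harness (the `cls` argument is inferred from context):

```python
self.run_backbone_test(
    cls=MixtralBackbone,
    init_kwargs=self.init_kwargs,
    input_data=self.input_data,
    expected_output_shape=(2, 5, 16),
    run_quantization_check=True,  # was False; re-enables the check
)
```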


@divyashreepathihalli (Collaborator) left a comment:

What about the aux_loss implementation for Mixtral?
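
For context, the reference Hugging Face implementation of Mixtral includes a Switch-Transformer-style router load-balancing auxiliary loss. A sketch of that idea in `keras.ops`, illustrative rather than the PR's implementation:

```python
from keras import ops


def load_balancing_loss(router_probs, expert_indices, num_experts):
    """Auxiliary loss pushing the router to use experts evenly.

    Args:
        router_probs: (num_tokens, num_experts) softmax router outputs.
        expert_indices: (num_tokens, top_k) selected expert ids per token.
        num_experts: total number of experts.
    """
    # One-hot per selection: (num_tokens, top_k, num_experts).
    expert_mask = ops.one_hot(expert_indices, num_experts)
    # Fraction of tokens routed to each expert; a token counts once per
    # expert even if multiple top-k slots select it.
    tokens_per_expert = ops.mean(ops.max(expert_mask, axis=1), axis=0)
    # Mean router probability mass assigned to each expert.
    prob_per_expert = ops.mean(router_probs, axis=0)
    # Smallest when both distributions are balanced across experts.
    return num_experts * ops.sum(tokens_per_expert * prob_per_expert)
```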
